[PyTorch] Pad V when Q/V head dims differ (MLA) for THD #2629

Merged: cyanguwa merged 3 commits into NVIDIA:main from HollowMan6:mla_thd on Jun 5, 2026

@HollowMan6 HollowMan6 commented Jan 27, 2026

Member

Description

For MLA, V should be padded when the Q/V head dims differ in the THD format.

Similar to NVIDIA/Megatron-LM#3003

Fixes NVIDIA/Megatron-LM#1698

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • pad V when Q/V head dims differ for THD

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Copilot AI review requested due to automatic review settings January 27, 2026 23:31
greptile-apps Bot commented Jan 27, 2026

Contributor

Greptile Summary

This PR fixes a bug in THD-format MLA (e.g. DeepSeek V3) where FlashAttention 2 was entirely blocked for mismatched Q/K and V head dimensions. It introduces zero-padding of V (and optionally Q/K) up to max(head_dim_qk, head_dim_v) before the FA2 call and trims the output back to the original V head dimension afterward, while also tightening the FA2 head-dim validity check to account for the padded dimension.

  • Adds _pad_qkv_head_dim and _trim_output helpers; applies them in the FA2 branch when head_dim_qk != head_dim_v and the backend is FA2 (guards correctly skip FA3/FA4 which support MLA natively).
  • Removes the blanket FA2-disable guard for mismatched head dims in utils.py and replaces it with a fa2_padded_head_dim-based validity check that also enforces a >192 restriction on older architectures.
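The pad-then-trim idea in the summary above can be checked numerically. What follows is a minimal sketch in plain Python lists, not the Transformer Engine implementation: `attention` and `pad_last_dim` are hypothetical helper names, and the tiny tensors are made up for illustration. It relies only on the fact that zero columns appended to V produce zero columns in the output, so slicing them off recovers the unpadded result.

```python
import math

def attention(q, k, v, scale):
    """Naive single-head softmax attention on lists of row vectors."""
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)  # subtract the max for numerical stability
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append([sum(wj / z * vj[c] for wj, vj in zip(w, v))
                    for c in range(len(v[0]))])
    return out

def pad_last_dim(rows, width):
    """Hypothetical helper: right-pad every row with zeros up to `width`."""
    return [row + [0.0] * (width - len(row)) for row in rows]

# Made-up example with head_dim_qk = 4 and head_dim_v = 2.
q = [[0.1, 0.2, 0.3, 0.4], [0.5, -0.1, 0.0, 0.2]]
k = [[0.3, 0.1, -0.2, 0.0], [0.0, 0.4, 0.1, -0.3], [0.2, 0.2, 0.2, 0.2]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
head_dim_qk, head_dim_v = 4, 2
scale = 1.0 / math.sqrt(head_dim_qk)

reference = attention(q, k, v, scale)

# Pad V to the Q/K head dim, run attention, then trim the output back.
padded_out = attention(q, k, pad_last_dim(v, head_dim_qk), scale)
trimmed = [row[:head_dim_v] for row in padded_out]
```

The padded columns of `padded_out` are all zero, which is why trimming loses no information.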

Confidence Score: 3/5

The non-FP8 THD MLA path works correctly, but the removed FA2 guard combined with the Float8TensorStorage exclusion from padding leaves the Float8 + MLA + FA2 combination unprotected — FA2 would receive tensors with mismatched head dimensions.

The removed blanket FA2 guard for head_dim_qk != head_dim_v in utils.py is not fully compensated by the padding logic in dot_product_attention.py, which skips Float8TensorStorage inputs. If a Float8 MLA configuration reaches FA2, it will call FA2 with unpadded mismatched head dims, causing a crash or incorrect results. The same guard previously covered this case safely.

Both changed files interact to create the regression: utils.py removes the guard that blocked FA2 for all mismatched-head-dim cases, while dot_product_attention.py adds padding but excludes Float8 tensors.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py | Adds _pad_qkv_head_dim/_trim_output helpers and applies FA2 V-padding for MLA. The Float8TensorStorage exclusion from padding combined with the removed FA2 guard creates a regression where Float8 + MLA + FA2 reaches FA2 with unpadded mismatched head dims. |
| transformer_engine/pytorch/attention/dot_product_attention/utils.py | Removes FA2 MLA blanket guard and replaces it with a padded-head-dim validity check; adds >192 restriction for older architectures. Logic is correct for the non-Float8 path; the Float8 regression stems from the interaction with dot_product_attention.py. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["DotProductAttention.forward - MLA head_dim_qk != head_dim_v"] --> B{use_flash_attention?}
    B -- No --> C[FusedAttention or Unfused]
    B -- Yes --> D{backend == FA2 version?}
    D -- No --> E[FA3/FA4 support MLA natively - no padding needed]
    D -- Yes --> F{value is Float8TensorStorage?}
    F -- Yes --> G["Skip padding - FA2 receives mismatched head dims - potential crash"]
    F -- No --> H[_pad_qkv_head_dim - pad V to head_dim_qk]
    H --> I[flash_attention with padded Q/K/V]
    I --> J{orig_qk_dim > orig_v_dim?}
    J -- Yes --> K[_trim_output - slice back to orig_head_dim_v]
    J -- No --> L[Return attn_out as-is]
    K --> M[Correct output]
    L --> M

Comments Outside Diff (1)

  1. transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py, lines 1657-1668

    Float8 MLA with FA2 receives unpadded mismatched tensors

    The old guard in utils.py unconditionally disabled FA2 for every head_dim_qk != head_dim_v case, including Float8. Now that guard is gone, FA2 can be selected for Float8 MLA. But the padding block here excludes Float8TensorStorage, so FA2 is called with the original mismatched dims — a crash or silent corruption that the old guard prevented.

    Either restore the disabled-FA2 guard in utils.py specifically for the Float8 + mismatched-head-dim case, or drop the not isinstance(value_layer, Float8TensorStorage) exclusion here so Float8 tensors get padded along the same path (if F.pad supports Float8TensorStorage).
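The first option in the comment above (keep FA2 disabled only for the Float8 plus mismatched-head-dim case) could look roughly like the sketch below. This is an illustration under stated assumptions, not the code in utils.py: the function name, its parameters, and the `max_head_dim` default are all hypothetical, and the real backend filter checks many more conditions.

```python
def fa2_mismatched_head_dim_ok(head_dim_qk, head_dim_v, inputs_are_fp8,
                               max_head_dim=256):
    """Hypothetical filter: may FA2 be kept when Q/K and V head dims differ?

    `max_head_dim` is an assumed limit used for illustration only; the
    actual limit depends on the FlashAttention version and GPU architecture.
    """
    if head_dim_qk == head_dim_v:
        # Equal head dims: this particular check does not apply,
        # other filters decide whether FA2 is usable.
        return True
    if inputs_are_fp8:
        # Float8 inputs are not padded, so FA2 would see mismatched dims.
        return False
    # Non-FP8 inputs get padded to the larger head dim before the FA2 call,
    # so the padded size is what must fit the backend limit.
    padded_head_dim = max(head_dim_qk, head_dim_v)
    return padded_head_dim <= max_head_dim

bf16_mla = fa2_mismatched_head_dim_ok(192, 128, inputs_are_fp8=False)
fp8_mla = fa2_mismatched_head_dim_ok(192, 128, inputs_are_fp8=True)
too_large = fa2_mismatched_head_dim_ok(64, 512, inputs_are_fp8=False)
```

The design point is that the validity check uses the padded head dim, while Float8 inputs are rejected up front because they never reach the padding step.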

Reviews (9): Last reviewed commit: "Support when v is larger than qk"

@greptile-apps greptile-apps Bot left a comment

Contributor

2 files reviewed, 2 comments

Comment thread tests/pytorch/attention/test_attention.py Outdated

Copilot AI left a comment

Contributor

Pull request overview

This PR adds support for Multi-head Latent Attention (MLA) with mismatched Q/V head dimensions in the THD format (a packed variable-length layout: total tokens, heads, head dim). When the value tensor has a smaller head dimension than the query/key tensors, the code pads the value tensor to match the Q/K head dimension, runs the attention operation, and then trims the output back to the original V dimension.

Changes:

  • Added padding logic for V tensor when head dimensions differ in THD format
  • Implemented trimming function to restore correct output dimensions after attention
  • Added test case for THD attention with mismatched Q/V head dimensions

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py | Implements padding of V tensor before attention and trimming of output after attention for THD format with mismatched Q/V head dimensions |
| tests/pytorch/attention/test_attention.py | Adds test case to verify THD attention works with different Q/V head dimensions |

Comment thread tests/pytorch/attention/test_attention.py Outdated

@greptile-apps greptile-apps Bot left a comment

Contributor

1 file reviewed, 1 comment

@greptile-apps greptile-apps Bot left a comment

Contributor

2 files reviewed, no comments

@ptrendx ptrendx added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label Apr 9, 2026
@cyanguwa cyanguwa requested a review from vcherepanov-nv April 22, 2026 22:21
@cyanguwa

Collaborator

This change should only be required by the FlashAttention backend. The other two backends FusedAttention and UnfusedDPA do support MLA (head_dim_qk != head_dim_v). I'd propose a few changes:

@vcherepanov-nv, could you help push this PR across the finish line? Thanks!

@HollowMan6

Member Author

Thank you @cyanguwa, I just cleaned up the PR and addressed your requested changes. Please let me know what you think, @vcherepanov-nv.

Copilot AI left a comment

Contributor

Pull request overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.


Comment thread transformer_engine/pytorch/attention/dot_product_attention/utils.py
cyanguwa commented Jun 1, 2026

Collaborator

/te-ci pytorch L0

@HollowMan6 HollowMan6 requested a review from cyanguwa June 1, 2026 17:23
cyanguwa commented Jun 3, 2026

Collaborator

/te-ci pytorch L0

@cyanguwa cyanguwa left a comment

Collaborator

@HollowMan6, could you help fix the failed tests please? Sorry, it's an oversight on my side too. Right now, _pad_value_layer and _trim_output both assume that V has a shorter head_dim than Q/K, but it could happen the other way as well.

// failed tests: "mla_1_0", "mla_1_1"

"mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128),
"mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128),

// failed error:
https://github.com/Dao-AILab/flash-attention/blob/d80a77103021c4e980f8cbbf85774f6a19e6474a/csrc/flash_attn/flash_api.cpp#L418

I wonder if we can make the pad function look something like this:

def _pad_qkv_head_dim(query_layer, key_layer, value_layer):
    ...
    return new_q, new_k, new_v, orig_head_dim_qk, orig_head_dim_v

Also, only call _trim_output when padded_head_dim_v > orig_head_dim_v; otherwise, it should be a no-op.
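A direction-agnostic version of that proposal might be sketched as follows. This is a plain-Python illustration on lists of rows, not the merged Transformer Engine code: `pad_rows`, `pad_qkv_head_dim`, and `trim_output` are hypothetical stand-ins that only mirror the suggested five-value return and the no-op trim. The sizes follow the failing "mla_1_0" config (head dim 64 for Q/K, 128 for V).

```python
def pad_rows(rows, width):
    """Right-pad every row with zeros up to `width` (no-op if already wide enough)."""
    return [row + [0.0] * (width - len(row)) for row in rows]

def pad_qkv_head_dim(q, k, v):
    """Pad whichever side is narrower so Q, K and V share one head dim."""
    orig_head_dim_qk, orig_head_dim_v = len(q[0]), len(v[0])
    target = max(orig_head_dim_qk, orig_head_dim_v)
    if orig_head_dim_qk < target:
        # Zero columns leave every q.k dot product unchanged. The softmax
        # scale must still be derived from orig_head_dim_qk, not `target`.
        q, k = pad_rows(q, target), pad_rows(k, target)
    if orig_head_dim_v < target:
        v = pad_rows(v, target)
    return q, k, v, orig_head_dim_qk, orig_head_dim_v

def trim_output(out, padded_head_dim_v, orig_head_dim_v):
    """Slice the output back only if V was actually padded; otherwise a no-op."""
    if padded_head_dim_v > orig_head_dim_v:
        return [row[:orig_head_dim_v] for row in out]
    return out

# V wider than Q/K, as in the failing test configs (64 vs 128).
q = [[1.0] * 64]
k = [[2.0] * 64]
v = [[3.0] * 128]
new_q, new_k, new_v, orig_qk, orig_v = pad_qkv_head_dim(q, k, v)

dot_before = sum(a * b for a, b in zip(q[0], k[0]))
dot_after = sum(a * b for a, b in zip(new_q[0], new_k[0]))

fake_out = [[0.5] * len(new_v[0])]  # stand-in for the attention output
result = trim_output(fake_out, len(new_v[0]), orig_v)
```

In this direction V is untouched, so `trim_output` returns its input unchanged, while Q and K grow from 64 to 128 columns without altering their dot products.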

Signed-off-by: Hollow Man <hollowman@opensuse.org>
Signed-off-by: Hollow Man <hollowman@opensuse.org>
Signed-off-by: Hollow Man <hollowman@opensuse.org>
@HollowMan6

Member Author

Thank you for pointing this out, @cyanguwa. Originally I didn't handle the v > qk case, since it isn't common practice for MLA, but since the test cases cover it, I have just pushed the changes accordingly.

cyanguwa commented Jun 4, 2026

Collaborator

/te-ci pytorch L0

@cyanguwa cyanguwa merged commit 8a5af97 into NVIDIA:main Jun 5, 2026
20 of 25 checks passed
@HollowMan6 HollowMan6 deleted the mla_thd branch June 5, 2026 18:02

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. org-contribution

Development

Successfully merging this pull request may close these issues.

[BUG]DotProductAttention:Disabling FlashAttention as it does not support MLA

5 participants